{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bark.generation import load_codec_model, generate_text_semantic\n",
    "from encodec.utils import convert_audio\n",
    "\n",
    "import torchaudio\n",
    "import torch\n",
    "\n",
    "# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "# Override by assigning 'cuda' or 'cpu' explicitly.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# Codec model used further down to convert the clip and to encode it into discrete codes\n",
    "model = load_codec_model(use_gpu=(device == 'cuda'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
    "from hubert.hubert_manager import HuBERTManager\n",
    "# Make sure the HuBERT model and its tokenizer are installed before they are loaded.\n",
    "# NOTE(review): presumably these download the files on first use into data/models/hubert/,\n",
    "# the location the next cell loads from - confirm against the linked repository.\n",
    "hubert_manager = HuBERTManager()\n",
    "hubert_manager.make_sure_hubert_installed()\n",
    "hubert_manager.make_sure_tokenizer_installed()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer \n",
    "# Load HuBERT for semantic tokens\n",
    "from hubert.pre_kmeans_hubert import CustomHubert\n",
    "from hubert.customtokenizer import CustomTokenizer\n",
    "\n",
    "# Build HuBERT from its checkpoint, then place it on the selected device\n",
    "hubert_model = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt')\n",
    "hubert_model = hubert_model.to(device)\n",
    "\n",
    "# Restore the CustomTokenizer from its checkpoint and place it on the same device\n",
    "tokenizer = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth')  # Automatically uses the right layers\n",
    "tokenizer = tokenizer.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the reference clip, convert it to the codec model's sample rate and\n",
    "# channel count, and move the result to the selected device\n",
    "audio_filepath = 'audio.wav' # the audio you want to clone (under 13 seconds)\n",
    "wav, sr = torchaudio.load(audio_filepath)\n",
    "wav = convert_audio(wav, sr, model.sample_rate, model.channels).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: run with autograd disabled (as the EnCodec encoding cell does)\n",
    "# so no gradient graph is built or kept alive for these tensors.\n",
    "with torch.no_grad():\n",
    "    # Feature vectors from HuBERT for the reference clip\n",
    "    semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)\n",
    "    # Map the feature vectors to the semantic tokens saved in the voice prompt\n",
    "    semantic_tokens = tokenizer.get_token(semantic_vectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract discrete codes from EnCodec\n",
    "with torch.no_grad():\n",
    "    encoded_frames = model.encode(wav.unsqueeze(0))  # unsqueeze adds the leading batch axis\n",
    "# Join the per-frame codes along the last (time) axis, then drop only the batch axis.\n",
    "# A bare .squeeze() would also remove any other length-1 axis, which would break the\n",
    "# 2-D indexing (codes[:2, :]) used when the prompt is saved.\n",
    "codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze(0)  # [n_q, T]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the codes and semantic tokens to the CPU as NumPy arrays for saving.\n",
    "# The is_tensor guards keep this cell safe to re-run: once a value has been\n",
    "# converted it is left alone instead of failing on the missing .cpu() method.\n",
    "if torch.is_tensor(codes):\n",
    "    codes = codes.cpu().numpy()\n",
    "if torch.is_tensor(semantic_tokens):\n",
    "    semantic_tokens = semantic_tokens.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "voice_name = 'output' # whatever you want the name of the voice to be\n",
    "output_path = f'bark/assets/prompts/{voice_name}.npz'\n",
    "# Store the three prompt arrays in one archive; the coarse prompt is the\n",
    "# first two rows of the fine codes\n",
    "np.savez(\n",
    "    output_path,\n",
    "    fine_prompt=codes,\n",
    "    coarse_prompt=codes[:2, :],\n",
    "    semantic_prompt=semantic_tokens,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# That's it! Now you can head over to generate.ipynb and pass your voice_name as the 'history_prompt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here's the generation code, copy-pasted for convenience"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bark.api import generate_audio\n",
    "from transformers import BertTokenizer\n",
    "from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic\n",
    "\n",
    "# Text to synthesize, and the name of the voice prompt to speak it with\n",
    "text_prompt = \"Hello, my name is Serpy. And, uh — and I like pizza. [laughs]\"\n",
    "voice_name = 'output' # use your custom voice name here if you have one"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download and load all models\n",
    "preload_options = {\n",
    "    # text stage\n",
    "    'text_use_gpu': True,\n",
    "    'text_use_small': False,\n",
    "    # coarse stage\n",
    "    'coarse_use_gpu': True,\n",
    "    'coarse_use_small': False,\n",
    "    # fine stage\n",
    "    'fine_use_gpu': True,\n",
    "    'fine_use_small': False,\n",
    "    # codec\n",
    "    'codec_use_gpu': True,\n",
    "    # loader settings\n",
    "    'force_reload': False,\n",
    "    'path': 'models',\n",
    "}\n",
    "preload_models(**preload_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple generation (single call)\n",
    "audio_array = generate_audio(\n",
    "    text_prompt,\n",
    "    history_prompt=voice_name,\n",
    "    text_temp=0.7,\n",
    "    waveform_temp=0.7,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generation with more control\n",
    "# Sampling settings shared by the semantic and coarse stages\n",
    "sampling_kwargs = dict(temp=0.7, top_k=50, top_p=0.95)\n",
    "\n",
    "# Each stage consumes the previous stage's output and the same voice prompt\n",
    "x_semantic = generate_text_semantic(text_prompt, history_prompt=voice_name, **sampling_kwargs)\n",
    "x_coarse_gen = generate_coarse(x_semantic, history_prompt=voice_name, **sampling_kwargs)\n",
    "x_fine_gen = generate_fine(x_coarse_gen, history_prompt=voice_name, temp=0.5)\n",
    "\n",
    "# Decode the fine-stage output into the array that is played and saved below\n",
    "audio_array = codec_decode(x_fine_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "# play audio (the widget is this cell's displayed result)\n",
    "Audio(data=audio_array, rate=SAMPLE_RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io.wavfile import write as write_wav\n",
    "import os\n",
    "\n",
    "# save audio\n",
    "filepath = \"/output/audio.wav\" # change this to your desired output path\n",
    "# Create the destination folder first: write_wav opens the file directly and\n",
    "# fails if the directory does not exist. The `or '.'` covers a bare filename.\n",
    "os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)\n",
    "write_wav(filepath, SAMPLE_RATE, audio_array)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
